/*
 * Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"

#if ENABLE_MULTI_DEVICE
#include <nccl.h>
#include <torch/extension.h>
#endif

#include <algorithm>
#include <functional>
#include <limits>
#include <memory>
#include <mutex>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#if ENABLE_MULTI_DEVICE

#ifdef _WIN32
#include <windows.h>
#else
#include <dlfcn.h>
#endif

namespace tensorrt_llm::common::nccl_util
{

//==============================================================================
// NCCL Helper - Dynamic Library Loading
//==============================================================================

// Helper class for dynamically loading NCCL symbols (ncclMemAlloc, ncclCommWindowRegister).
// This allows the code to work with NCCL libraries that may or may not have these symbols.
// Accessed only through the singleton returned by getInstance(); copying and moving are disabled.
// All member functions are defined out of line.
class NCCLHelper
{
public:
    // Returns the process-wide instance.
    static NCCLHelper& getInstance();

    // Dynamic loading function type definition
    using ncclCommWindowRegisterFunc = ncclResult_t (*)(ncclComm_t, void*, size_t, ncclWindow_t*, int);
    using ncclMemAllocFunc = ncclResult_t (*)(void**, size_t);

    // Get function pointer for ncclCommWindowRegister.
    // NOTE(review): presumably nullptr when the symbol could not be resolved — confirm in the definition.
    ncclCommWindowRegisterFunc getNCCLCommWindowRegister();

    // Get function pointer for ncclMemAlloc.
    // NOTE(review): presumably nullptr when the symbol could not be resolved — confirm in the definition.
    ncclMemAllocFunc getNCCLMemAlloc();

    // Check if NCCL library is successfully loaded
    bool isLoaded() const;

    NCCLHelper(NCCLHelper const&) = delete;
    NCCLHelper& operator=(NCCLHelper const&) = delete;
    NCCLHelper(NCCLHelper&&) = delete;
    NCCLHelper& operator=(NCCLHelper&&) = delete;

private:
    NCCLHelper();
    ~NCCLHelper();

    void loadNCCLLibrary();
    void* loadLibraryHandle(char const* libName);
    void* getSymbolAddress(void* handle, char const* symbolName);

    // In-class defaults guarantee these never hold indeterminate values, even if the out-of-line
    // constructor does not set every field. A mem-initializer in that constructor still takes precedence.
#ifdef _WIN32
    HMODULE mLibraryHandle{nullptr};
#else
    void* mLibraryHandle{nullptr};
#endif

    ncclCommWindowRegisterFunc mNCCLCommWindowRegister{nullptr};
    ncclMemAllocFunc mNCCLMemAlloc{nullptr};
    bool mIsLoaded{false};
};

//==============================================================================
// NCCL Resource Management
//==============================================================================

// Resource cleanup function type. Called before the NCCL communicator is destroyed.
// Takes no arguments and returns nothing; any state it needs must be captured by the callable itself.
using ResourceCleanupFunc = std::function<void()>;

// Manages resources associated with NCCL communicators. Thread-safe singleton that maintains
// a pool of resources per NCCL comm. Resources are automatically cleaned up when the
// communicator is destroyed.
// Only reachable through getInstance(); copying and moving are disabled. All member functions
// are defined out of line.
class NcclCommResourceManager
{
public:
    // Returns the process-wide instance. Declared noexcept.
    static NcclCommResourceManager& getInstance() noexcept;

    // Register a resource cleanup function for a specific NCCL communicator.
    // The cleanup function will be called before ncclCommDestroy.
    // Thread-safe: Uses global mutex to serialize all operations.
    // debugName is optional (defaults to nullptr).
    // NOTE(review): the callable is stored until cleanupResources runs, so anything it captures by
    // pointer/reference must outlive that call — confirm storage semantics in the definition.
    void registerResource(ncclComm_t comm, ResourceCleanupFunc cleanup, char const* debugName = nullptr);

    // Cleanup all resources associated with a communicator. Called automatically by
    // the shared_ptr deleter before ncclCommDestroy.
    // Thread-safe: Uses global mutex to serialize cleanup operations.
    // Order-preserving: Resources are cleaned up in registration order.
    // Declared noexcept: an exception escaping it would call std::terminate.
    void cleanupResources(ncclComm_t comm) noexcept;

    // Check if a communicator has registered resources.
    bool hasResources(ncclComm_t comm) const noexcept;

    // Get the number of resources registered for a communicator.
    size_t getResourceCount(ncclComm_t comm) const noexcept;

    NcclCommResourceManager(NcclCommResourceManager const&) = delete;
    NcclCommResourceManager& operator=(NcclCommResourceManager const&) = delete;
    NcclCommResourceManager(NcclCommResourceManager&&) = delete;
    NcclCommResourceManager& operator=(NcclCommResourceManager&&) = delete;

private:
    NcclCommResourceManager() = default;
    ~NcclCommResourceManager() = default;

    // A registered cleanup callable paired with a string.
    // NOTE(review): the string is presumably the debugName passed to registerResource — confirm.
    using ResourceEntry = std::pair<ResourceCleanupFunc, std::string>;

    // mutable so the const query methods can lock it.
    // NOTE(review): presumably guards mCommResources — confirm in the definitions.
    mutable std::mutex mMutex;
    // Per-communicator list of registered entries; a vector keeps registration order.
    std::unordered_map<ncclComm_t, std::vector<ResourceEntry>> mCommResources;
};

// RAII helper to register a resource with a NCCL communicator.
// Automatically registers cleanup function on construction. The cleanup is invoked with the resource
// when NcclCommResourceManager cleans up the communicator's resources.
//
// Lifetime: the resource and its cleanup function live in a heap-allocated state object. That object is
// shared between this wrapper and the callback registered with the manager. The callback therefore stays
// valid even if the wrapper is destroyed before the communicator; in that case the resource is kept alive
// until the manager runs (and later drops) the callback. Destroying the wrapper neither runs the cleanup
// nor unregisters it.
template <typename ResourceType>
class NcclCommResource
{
public:
    // comm:      communicator the resource is tied to.
    // resource:  forwarded into the shared state.
    // cleanup:   called with the resource during communicator cleanup; may be empty (then nothing runs).
    // debugName: optional label forwarded to NcclCommResourceManager::registerResource.
    NcclCommResource(ncclComm_t comm, ResourceType&& resource, std::function<void(ResourceType&)> cleanup,
        char const* debugName = nullptr)
        : mComm(comm)
        , mState(std::make_shared<State>(std::forward<ResourceType>(resource), std::move(cleanup)))
    {
        // Capture the shared state by value, never `this`. The manager may invoke the callback after this
        // wrapper is gone, and a captured `this` would then be a dangling pointer.
        NcclCommResourceManager::getInstance().registerResource(
            comm,
            [state = mState]()
            {
                if (state->cleanup)
                {
                    state->cleanup(state->resource);
                }
            },
            debugName);
    }

    ResourceType& get()
    {
        return mState->resource;
    }

    ResourceType const& get() const
    {
        return mState->resource;
    }

    NcclCommResource(NcclCommResource const&) = delete;
    NcclCommResource& operator=(NcclCommResource const&) = delete;
    NcclCommResource(NcclCommResource&&) = delete;
    NcclCommResource& operator=(NcclCommResource&&) = delete;

private:
    // State shared between this wrapper and the registered cleanup callback.
    struct State
    {
        State(ResourceType&& res, std::function<void(ResourceType&)> fn)
            : resource(std::forward<ResourceType>(res))
            , cleanup(std::move(fn))
        {
        }

        ResourceType resource;
        std::function<void(ResourceType&)> cleanup;
    };

    ncclComm_t mComm; // Communicator the resource was registered with.
    std::shared_ptr<State> mState;
};

//==============================================================================
// NCCL Window Buffer Allocation
//==============================================================================

// Describes a device buffer together with the NCCL window it is registered under.
// Field naming follows the UBBuffer interface (see the per-field notes and addr()).
struct NCCLWindowBuffer
{
    void* ptr;           // Device pointer (same as UBBuffer.addr)
    int handle;          // Buffer handle/index (for compatibility with UB interface)
    size_t size;         // Size in bytes
    ncclWindow_t window; // NCCL window handle

    // Every argument is optional. The all-default form is an empty descriptor:
    // null pointer, handle -1, size 0, no window.
    NCCLWindowBuffer(void* p = nullptr, int h = -1, size_t s = 0, ncclWindow_t w = nullptr)
        : ptr{p}
        , handle{h}
        , size{s}
        , window{w}
    {
    }

    // True only when the descriptor has backing memory (non-null pointer, non-zero size)
    // AND a registration (non-negative handle, non-null window).
    [[nodiscard]] bool isValid() const
    {
        bool const hasMemory = ptr != nullptr && size > 0;
        bool const hasRegistration = handle >= 0 && window != nullptr;
        return hasMemory && hasRegistration;
    }

    // Logical negation of isValid().
    [[nodiscard]] bool invalid() const
    {
        return !isValid();
    }

    // Alias for compatibility with UBBuffer interface
    void* addr() const
    {
        return this->ptr;
    }
};

// Manages NCCL window-registered buffers with pooling and automatic cleanup.
// Buffers are tied to the lifetime of their associated NCCL communicator.
// Only reachable through getInstance(); copying and moving are disabled. All member functions are
// defined out of line.
class NCCLWindowAllocator
{
public:
    // Returns the process-wide instance.
    static NCCLWindowAllocator& getInstance();

    // Request a buffer for the given communicator and size.
    // If an unused buffer of at least the requested size exists for this communicator, it will be reused.
    // Uses best-fit strategy: selects the smallest available buffer that meets the size requirement.
    // Otherwise, a new buffer is allocated and registered.
    NCCLWindowBuffer requestBuffer(ncclComm_t comm, size_t size);

    // Search for a buffer by pointer. Returns an invalid buffer if not found.
    // This matches the UBManager.search_buffer() interface.
    NCCLWindowBuffer searchBuffer(ncclComm_t comm, void* ptr) const;

    // Release a buffer back to the pool for potential reuse
    void releaseBuffer(ncclComm_t comm, void* ptr);

    // Get the window handle for a specific buffer pointer
    ncclWindow_t getWindow(ncclComm_t comm, void* ptr) const;

    // Get the size of a specific buffer pointer
    size_t getSize(ncclComm_t comm, void* ptr) const;

    // Get buffer info by pointer
    NCCLWindowBuffer getBufferInfo(ncclComm_t comm, void* ptr) const;

    // Get the number of buffers allocated for a communicator
    size_t getBufferCount(ncclComm_t comm) const;

    // Get the number of buffers in use for a communicator
    size_t getBufferInUseCount(ncclComm_t comm) const;

    // Check if a communicator is valid (non-null)
    // Note: We don't track cleaned-up comms because NCCL can reuse memory addresses.
    // All non-null comms are considered valid and will be registered when first used.
    bool isCommValid(ncclComm_t comm) const noexcept;

    NCCLWindowAllocator(NCCLWindowAllocator const&) = delete;
    NCCLWindowAllocator& operator=(NCCLWindowAllocator const&) = delete;
    NCCLWindowAllocator(NCCLWindowAllocator&&) = delete;
    NCCLWindowAllocator& operator=(NCCLWindowAllocator&&) = delete;

private:
    NCCLWindowAllocator() = default;
    ~NCCLWindowAllocator() = default;

    // Allocate a new buffer and register it with NCCL as a window
    NCCLWindowBuffer allocateAndRegisterBuffer(ncclComm_t comm, size_t size, int handle);

    // Search for a buffer by pointer (assumes mMutex is already locked)
    NCCLWindowBuffer searchBufferLocked(ncclComm_t comm, void* ptr) const;

    // Register cleanup function for all buffers associated with a communicator
    void registerBufferCleanup(ncclComm_t comm);

    // Cleanup all buffers for a specific communicator
    void cleanupBuffersForComm(ncclComm_t comm) noexcept;

    // One pooled buffer plus whether it is currently handed out.
    struct BufferEntry
    {
        NCCLWindowBuffer buffer;
        // Defaults to "free" so a default-constructed entry never carries an indeterminate flag.
        // The default member initializer keeps BufferEntry an aggregate (C++14+), so
        // `BufferEntry{buf, true}` still works.
        bool inUse{false};
    };

    // mutable so const queries can lock it. searchBufferLocked documents that callers hold it.
    // NOTE(review): presumably it guards both containers below — confirm in the definitions.
    mutable std::mutex mMutex;
    // Pooled buffers per communicator.
    std::unordered_map<ncclComm_t, std::vector<BufferEntry>> mBufferPool;
    // NOTE(review): presumably the communicators registerBufferCleanup has already run for — confirm.
    std::unordered_set<ncclComm_t> mRegisteredComms;
};

// RAII wrapper for NCCL window buffers
class ScopedNCCLWindowBuffer
{
public:
    ScopedNCCLWindowBuffer(ncclComm_t comm, size_t size)
        : mComm(comm)
        , mBuffer(NCCLWindowAllocator::getInstance().requestBuffer(comm, size))
    {
    }

    ~ScopedNCCLWindowBuffer()
    {
        if (mBuffer.isValid())
        {
            NCCLWindowAllocator::getInstance().releaseBuffer(mComm, mBuffer.ptr);
        }
    }

    void* getPtr() const
    {
        return mBuffer.ptr;
    }

    size_t getSize() const
    {
        return mBuffer.size;
    }

    ncclWindow_t getWindow() const
    {
        return mBuffer.window;
    }

    NCCLWindowBuffer const& getBuffer() const
    {
        return mBuffer;
    }

    ScopedNCCLWindowBuffer(ScopedNCCLWindowBuffer const&) = delete;
    ScopedNCCLWindowBuffer& operator=(ScopedNCCLWindowBuffer const&) = delete;
    ScopedNCCLWindowBuffer(ScopedNCCLWindowBuffer&&) = delete;
    ScopedNCCLWindowBuffer& operator=(ScopedNCCLWindowBuffer&&) = delete;

private:
    ncclComm_t mComm;
    NCCLWindowBuffer mBuffer;
};

// Creates a PyTorch tensor backed by an NCCL window buffer.
// The tensor will automatically release the buffer back to the pool when its storage is freed.
// This is analogous to torch_ext::create_userbuffers_tensor() but for NCCLWindowAllocator.
//
// comm:  communicator whose window-buffer pool the memory is taken from.
// shape: tensor sizes. Every dimension must be non-negative; the tensor gets contiguous row-major strides.
// dtype: element type.
// Returns the tensor together with the NCCLWindowBuffer backing it.
// Throws std::runtime_error if:
//   - a dimension is negative,
//   - the byte size does not fit in int64_t, or
//   - the allocator returns an invalid buffer.
// Exceptions from the allocator or torch::from_blob propagate. If torch::from_blob throws, the buffer is
// returned to the pool before the exception propagates.
inline std::pair<torch::Tensor, NCCLWindowBuffer> createNCCLWindowTensor(
    ncclComm_t comm, at::IntArrayRef shape, torch::ScalarType dtype)
{
    // Validate the shape before touching the allocator. Otherwise a negative dimension, or an overflowing
    // product (signed overflow is UB), would be converted into a huge size_t request.
    bool hasZeroDim = false;
    for (auto const dim : shape)
    {
        if (dim < 0)
        {
            std::ostringstream oss;
            oss << "Failed to create NCCL window tensor: negative dimension " << dim
                << " (comm=" << static_cast<void*>(comm) << ")";
            throw std::runtime_error(oss.str());
        }
        hasZeroDim = hasZeroDim || dim == 0;
    }

    // Byte size = numel * element size. A zero-sized dimension makes the tensor empty whatever the other
    // dimensions are, so only a non-empty shape needs the overflow check.
    int64_t buffer_size = 0;
    if (!hasZeroDim)
    {
        buffer_size = static_cast<int64_t>(torch::elementSize(dtype));
        for (auto const dim : shape)
        {
            // dim > 0 here; check before multiplying so the product never overflows.
            if (buffer_size > std::numeric_limits<int64_t>::max() / dim)
            {
                std::ostringstream oss;
                oss << "Failed to create NCCL window tensor: buffer size overflows int64_t "
                    << "(comm=" << static_cast<void*>(comm) << ")";
                throw std::runtime_error(oss.str());
            }
            buffer_size *= dim;
        }
    }

    // Contiguous (row-major) strides, in elements. Cannot overflow: every partial product of the dimensions
    // is bounded by the validated buffer_size.
    std::vector<int64_t> strides_vec(shape.size());
    if (!shape.empty())
    {
        strides_vec[shape.size() - 1] = 1;
        for (int64_t i = static_cast<int64_t>(shape.size()) - 1; i >= 1; --i)
        {
            strides_vec[i - 1] = strides_vec[i] * shape[i];
        }
    }

    // Request buffer from allocator
    auto& allocator = NCCLWindowAllocator::getInstance();
    auto buffer = allocator.requestBuffer(comm, static_cast<size_t>(buffer_size));

    // Defensive validation: ensure buffer is valid before proceeding
    if (!buffer.isValid())
    {
        std::ostringstream oss;
        oss << "Failed to allocate NCCL window buffer: invalid buffer returned from requestBuffer "
            << "(comm=" << static_cast<void*>(comm) << ", buffer_size=" << buffer_size << ")";
        throw std::runtime_error(oss.str());
    }

    // The buffer can be handed back from two places: the tensor's deleter, and the error path below.
    // Depending on where torch::from_blob fails, its deleter may or may not already have run. A shared
    // once_flag makes the release idempotent, which avoids both a leaked in-use buffer and a double release.
    auto releaseFlag = std::make_shared<std::once_flag>();
    auto releaseOnce = [comm, ptr = buffer.ptr, releaseFlag]()
    {
        std::call_once(
            *releaseFlag, [comm, ptr]() { NCCLWindowAllocator::getInstance().releaseBuffer(comm, ptr); });
    };

    // Create custom deleter that releases the buffer
    auto deleter = [releaseOnce](void*) { releaseOnce(); };

    // Create tensor from the buffer
    torch::Tensor tensor;
    try
    {
        tensor = torch::from_blob(buffer.ptr, shape, strides_vec, deleter, torch::dtype(dtype).device(torch::kCUDA));
    }
    catch (...)
    {
        releaseOnce();
        throw;
    }

    return std::make_pair(std::move(tensor), buffer);
}

} // namespace tensorrt_llm::common::nccl_util

#endif // ENABLE_MULTI_DEVICE
